Skip to content

Add DCP compatibility for FSDP2-TP sharding in TransformerEngine.#2713

Open
cspades wants to merge 13 commits intoNVIDIA:mainfrom
cspades:cye/fsdp2-tp-dcp
Open

Add DCP compatibility for FSDP2-TP sharding in TransformerEngine.#2713
cspades wants to merge 13 commits intoNVIDIA:mainfrom
cspades:cye/fsdp2-tp-dcp

Conversation

@cspades
Copy link
Member

@cspades cspades commented Feb 26, 2026

Summary

  • Support (H/F)SDP2 x TP strided sharding, and DTensor FP8 parameters for Torch DCP checkpointing, across all TransformerEngineBaseModule(s).
    • Except GroupedLinear, pending FSDP2 standalone pipe-cleaning. All other modules under transformer_engine.pytorch.modules are supported.
    • FusibleOperation support is also a WIP, except for LayerNorm or RMSNorm which are TE modules.
  • Associated with BioNeMo-Recipes Llama3 TP: Enable TransformerEngine-backed Tensor Parallelism with Llama3. bionemo-framework#1483
    • Notably, TransformerEngine TP can be easily mixed with DTensor-based TP when unified by Torch DCP! In the Llama3 recipe, we use DTensor-based TP on the torch.nn.Embedding, TransformerEngine-based TP on the LM head, and weight-tie the LM head to the torch.nn.Embedding, which is why we do not need to call set_device_mesh for the LM head!
  • Credit to @pstjohn for coming up with this idea!

Usage / Documentation

(tp_mesh and weight_mesh can also be passed in TEModule.__init__.)

    def set_device_mesh(
        self,
        tp_mesh: Optional[DeviceMesh] = None,
        weight_mesh: Optional[DeviceMesh] = None,
    ) -> None:
        """
        Set DeviceMesh(s) used for sharding weights and convert main weights into DTensor
        depending on the TransformerEngine class to support FSDP-TP sharding with FSDP2.

        TransformerEngine manages tensor parallel mechanics, while DTensor offers seamless
        integration with Torch DCP checkpointing. This method should only be invoked when
        using DTensor parameters, e.g. when using FSDP2 or DCP.

        When FSDP2 fully_shard() encounters any DTensor Shard(s), it will automatically
        convert them into FSDP-TP strided or non-strided shards depending on the current
        sharding dimension and factor of the DTensor. When the sharding dimension of FSDP
        matches that of TP, FSDP uses a _StridedShard placement type instead of Shard.
        This experimental FSDP-TP logic presides in this FSDP2 initialization function:
        ``torch.distributed.fsdp._fully_shard._fsdp_param._init_sharded_param``

        Parameters
        ----------
        tp_mesh : Optional[DeviceMesh]
            A 1-D DeviceMesh containing a TP mesh dimension, e.g. device_mesh["tp"].
            Only required when using TP with DTensor parameters, e.g. for FSDP2 or DCP.
        weight_mesh : Optional[DeviceMesh]
            A 1-D DeviceMesh containing a weight-sharding mesh dimension. Only required
            when using the FP8 Current (per-tensor) Scaling recipe on sharded DTensor
            parameters and if the DTensor DeviceMesh includes dimensions that do not
            shard weights, such as in the case of HSDP (DP-Replicate x DP-Shard).
            For example:
                - device_mesh["dp"] for FSDP.
                - device_mesh["dp_cp"] if using CP ranks in FSDP.
                - device_mesh["dp_shard"] if using HSDP ("dp_replicate", "dp_shard").
                - device_mesh["tp"] if using TP.
                - device_mesh["dp_cp_tp"] if strided-sharding with FSDP-TP.
        """

Details

DTensor Lifecycle in TransformerEngine

  • Initialization
    • __init__
      • TransformerEngine model parameters are initialized either on device or meta device with the appropriate tp_size and TP sharding strategy, e.g. parallel_mode and sequence_parallel.
    • TransformerEngineModule.set_device_mesh(tp_mesh, weight_mesh)
      • Converts parameters to DTensor with appropriate TP placement(s) based on the TP sharding strategy specified in __init__, using transformer_engine.pytorch.distributed._convert_param_to_dtensor_param.
        • tp_mesh is a 1-D DeviceMesh containing the TP ProcessGroup that will be registered with the TransformerEngine module.
        • weight_mesh is the 1-D DeviceMesh containing the ProcessGroup that shards TransformerEngine module weights, the flattened combination of groups such as FSDP and TP. Specifically, it excludes non-weight groups such as DP-Replicate when using HSDP or HSDP-TP and is mainly required for per-Tensor scaling recipes like Float8CurrentScaling.
      • Needs to be invoked prior to fully_shard (which responds to the TP placements) and prior to reset_parameters(defer_init=False), which quantizes parameters.
      • Can also be directly invoked during __init__(tp_mesh, weight_mesh) for supported TransformerEngine modules.
    • fully_shard shards the TransformerEngine model with FSDP2.
      • When fully_shard encounters TP sharding on dim=0, it will use a _StridedShard for DP. Put simply, this "pre-shards" the data prior to sharding on the current placement, followed by concatenating the pre-shards to get strided shards that will be re-sharded by the next placement. This effectively reverses the sharding order when processing the placements from left-to-right, and distributes shards as if we sharded on TP first, then FSDP, as required, even though DP appears before TP in the DeviceMesh and DTensor.placements. (See Appendix for visualization of this sharding strategy.)
    • reset_parameters is called if using meta device initialization.
  • Training
    • Pre-forward, FSDP2 all-gathers the sharded DTensor "main" weight that it registered during fully_shard. (Note that this essentially shares the same properties as the compute weight besides shape, and supporting tools such as FusedAdam must be used to properly handle high-precision main weights.)
      • When using FSDP2 x TP, the all-gathered Tensor is actually a TP-sharded DTensor, which deviates from the original FSDP2 paradigm where the all-gathered Tensor is fully-unsharded and the DTensor wrapping is discarded. To support these DTensor compute weights in TransformerEngine modules, we utilize transformer_engine.pytorch.distributed._extract_trainable_tensor_from_dtensor to localize the DTensor and also inherit requires_grad attribute from the DTensor parameter as the local Tensor has this un-set during DTensor.from_local(Tensor) for FP8 parameters specifically!
    • Post-backward, the Tensor gradient is converted to DTensor and attached to the DTensor.grad attribute. Handled by DTensor <> Tensor Autograd conversion functions, and in the case of FusibleOperation, casted during the backward implementation.

QuantizedTensor Storage

  • When both row and column data are None, we send untyped_storage() to a default 1-byte storage that unblocks DCP checkpoint loading assertions using this as a definition for "emptiness". This is because a storage of 0 bytes is a data_ptr() = nullptr and breaks DCP.
    • While untyped_storage is not used anywhere in TransformerEngine, it may break code that uses Storage to figure out if a Tensor is empty or not, as now QuantizedTensor storage will always be a 1-byte storage even when both row and column data are not set. Those checks would instead need to compare the storage size against 1 byte instead of 0 bytes.

Bugs

  • Fix bug where "shard" was the presumed weight sharding sub-mesh in the DTensor.device_mesh. Now, users can precisely specify their own custom weight-sharding DeviceMesh for per-tensor amax_reduction_group via the set_device_mesh(weight_mesh) API.
  • TransformerEngineBaseModule: self.quantizers = {"scaling_fwd": [], "scaling_bwd": []}

Testing

# TransformerEngine Main
[Rank 0] (after 1 iterations) memory (MB) | allocated: 23511.65 | max allocated: 25189.68 | reserved: 25678.00 | max reserved: 25678.00
 [2026-03-02 09:55:17.189564] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12715.7 | throughput per GPU (TFLOP/s/GPU): 530.6 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124915E+00 | loss scale: 1.0 | grad norm: 5.474 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:55:29.768521] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12578.7 | throughput per GPU (TFLOP/s/GPU): 536.4 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.143806E+00 | loss scale: 1.0 | grad norm: 5.366 | number of skipped iterations:   0 | number of nan iterations:   0 |

# Post-DCP Modifications (This PR)
[Rank 0] (after 2 iterations) memory (MB) | allocated: 23511.65 | max allocated: 29783.24 | reserved: 25678.00 | max reserved: 31510.00
 [2026-03-02 09:29:36.550070] iteration       99/15258789 | consumed samples:        12672 | elapsed time per iteration (ms): 12556.5 | throughput per GPU (TFLOP/s/GPU): 537.3 | learning rate: 4.866046E-07 | global batch size:   128 | lm loss: 1.124463E+00 | loss scale: 1.0 | grad norm: 5.471 | number of skipped iterations:   0 | number of nan iterations:   0 |
 [2026-03-02 09:29:49.216068] iteration      100/15258789 | consumed samples:        12800 | elapsed time per iteration (ms): 12665.7 | throughput per GPU (TFLOP/s/GPU): 532.7 | learning rate: 4.915198E-07 | global batch size:   128 | lm loss: 1.142863E+00 | loss scale: 1.0 | grad norm: 5.355 | number of skipped iterations:   0 | number of nan iterations:   0 |
  • NOTE(@cspades): DelayedScaling has DCP save/load disparity issues, i.e. on the scale of +/-1 to the uint8 parameter checkpoint!

Appendix

_StridedShard - Using FSDP2 x TP Strided-Sharding

# (DP=4, TP=2)
(_StridedShard(dim=0, sf=2), Shard(dim=0))

┌───┬───┐
│ 0 │ 4 │ ← DP=0
├───┼───┤
│ 1 │ 5 │ ← DP=1
├───┼───┤          FSDP all-gather happens across the DP ranks,
│ 2 │ 6 │ ← DP=2   so we need to form the 0-3 and 4-7 TP shards!
├───┼───┤
│ 3 │ 7 │ ← DP=3
└───┴───┘
  ↑   ↑
TP=0 TP=1

When redistribute'ing a global DTensor to (_StridedShard(dim=0, sf=2), Shard(dim=0)), DTensor will perform the following steps:

  • Pre-shard the Tensor data with respect to the stride / shard factor, which is defined as the product of the parallelism sizes of all Shard placements to the right of _StridedShard. (In the above example, since TP=2, the factor is 2.)
    • [0 1 2 3 4 5 6 7] -> [0 1 2 3] and [4 5 6 7].
    • In the context of this PR and fully_shard, this has already been done via initializing the TransformerEngine module with TP and calling _convert_param_to_dtensor_param!
  • Shard the pre-shards for _StridedShard.
    • [0] [1] [2] [3] and [4] [5] [6] [7]
  • Concatenate the strided shards.
    • [0 4] [1 5] [2 6] [3 7], which are assigned to the _StridedShard ranks.
    • Note that this is very different if we did left-to-right-sharding, which would have given us [0 1] [2 3] [4 5] [6 7]!
  • Subsequently / finally, each strided shard is sharded on the Shard placement.
    • [0] [4] / [1] [5] / [2] [6] / [3] [7], which are assigned to the Shard ranks.
    • Note that this is very different if we did left-to-right sharding, which would have given us [0] [1] / [2] [3] / [4] [5] / [6] [7]!

PyTorch also supports the inverse / un-sharding of this redistribute, which is literally the inverse of these simple operations! (Though things get a bit more complicated with un-even shards from odd-numbered dimension sizes.)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 4, 2026

Greptile Summary

This PR adds DTensor-based DCP checkpoint compatibility for all TransformerEngineBaseModules when used with (H/F)SDP2 × TP strided sharding. It introduces a set_device_mesh(tp_mesh, weight_mesh) API on every TE module, a new _convert_param_to_dtensor_param utility, and an identity-preserving _ToLocalIdentity autograd function that ensures FSDP2's in-place post-all-gather updates remain visible through the stored ctx reference. The approach is architecturally sound — TE continues to own TP collectives while DTensor provides the DCP-compatible sharding metadata, allowing clean interop with torch.distributed.checkpoint.

Key changes and callouts:

  • set_device_mesh(tp_mesh, weight_mesh) added to Linear, LayerNormLinear, LayerNormMLP, DotProductAttention, MultiheadAttention, TransformerLayer, LayerNorm, and RMSNorm. Each method converts parameters to appropriately-placed DTensors and wires up amax_reduction_group for Float8CurrentScaling.
  • quantizers default fixed from {}[] in base.py to support integer-indexed access already used throughout the quantizer pipeline.
  • Incidental backward bug fix in _LayerNormMLP: the guard isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) was always False (a Quantizer is never a QuantizedTensorStorage), making ctx.fc1_weight.update_usage(columnwise_usage=True) dead code. The fix changes the subject to ctx.fc1_weight.
  • _default_storage (1-byte UntypedStorage) added to all block-format quantized tensor types to satisfy DCP's non-null storage assertion on empty tensors.
  • Test suite gains a full DCP save/load round-trip with model and optimizer state parity verification, plus a standalone run_fsdp2_allgather.py for FP8 all-gather correctness.
  • Known deferred items: GroupedLinear support, FusibleOperation support, DTensor bias registration when use_bias=False, and DelayedScaling minor uint8 ±1 DCP parity disparity.

Confidence Score: 3/5

  • Mostly safe to merge with awareness of known deferred issues; the main risk is the incidental _LayerNormMLP backward bug fix changing established behavior.
  • The DTensor conversion logic and _ToLocalIdentity design are sound and backed by Megatron parity tests. However: (1) the _LayerNormMLP backward update_usage fix changes long-standing (if broken) behavior and should get explicit test coverage; (2) args.sharding_dims None-guard in run_fsdp2_model.py is still unresolved (acknowledged in previous threads); (3) the private DTensor._local_tensor API is used in three places without a stabilization note; (4) _default_storage allocates with torch.cuda.current_device() which could be wrong in edge cases; (5) GroupedLinear and FusibleOperation support are explicitly deferred, leaving parts of the TE surface unsupported.
  • transformer_engine/pytorch/module/layernorm_mlp.py (backward bug fix), tests/pytorch/distributed/run_fsdp2_model.py (None guard for sharding_dims), transformer_engine/pytorch/tensor/storage/*.py (_default_storage device)

Important Files Changed

Filename Overview
transformer_engine/pytorch/distributed.py Adds _convert_param_to_dtensor_param (converts plain params to DTensor) and _ToLocalIdentity / _extract_trainable_tensor_from_dtensor (identity-preserving DTensor→local extraction for FSDP2 in-place updates). The _ToLocalIdentity.forward intentionally uses the private DTensor._local_tensor attribute — stable in practice but worth tracking.
transformer_engine/pytorch/module/base.py Fixes quantizers default from {} to [] (needed for integer-indexed access), localizes DTensor inputs before passing to TE C++ kernels, and improves _quantize_and_configure_parameter to fall back to device_mesh.get_group() for amax reduction when weight_mesh is not explicitly provided. Changes look correct.
transformer_engine/pytorch/module/linear.py Adds set_device_mesh, _get_bias_tensors, and _set_tensor_parallel_attributes. The old TP-attribute setting logic is refactored out of reset_parameters into _set_tensor_parallel_attributes. _get_weight_and_bias_tensors return type is changed from Tuple to List — callers should be verified for tuple-unpacking compatibility.
transformer_engine/pytorch/module/layernorm_mlp.py Adds full set_device_mesh and _get_bias/layernorm_weight_and_bias helpers. Contains a notable incidental bug fix: the backward isinstance check for update_usage was guarded by ctx.fc1_weight_quantizer (always False) and is corrected to ctx.fc1_weight. Unconditional fc1_bias / fc2_bias DTensor conversion when use_bias=False is a known deferred issue.
transformer_engine/pytorch/module/layernorm_linear.py Parallel structure to layernorm_mlp.py; adds set_device_mesh, _get_bias_tensors, _get_layernorm_weight_and_bias, and _set_tensor_parallel_attributes. Row-parallel placement correctly uses Shard(dim=0) for the LayerNorm parameters (replicated across TP in column-parallel, sequence-sharded in row-parallel). Logic looks sound.
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py Adds _default_storage (1-byte UntypedStorage) to unblock DCP assertion on empty tensors. Allocation uses torch.cuda.current_device() which assumes CUDA is initialized and the current device is correct — correct for typical TE GPU usage but potentially fragile in CPU-only test environments.
transformer_engine/pytorch/tensor/nvfp4_tensor.py Adds untyped_storage override and _default_storage; docstring incorrectly says "MXFP8Tensor" (copy-paste from mxfp8_tensor.py). Logic is otherwise correct and mirrors MXFP8/Float8Blockwise implementations.
tests/pytorch/distributed/run_fsdp2_model.py Major addition of DCP checkpoint save/load round-trip testing, HSDP-TP mesh construction, and full parity verification of model/optimizer state. args.sharding_dims can be None (nargs="+", not required), causing a TypeError in the new '_'.join(...) CKPT_DIR construction — flagged in previous threads. Otherwise the test logic is solid.

Sequence Diagram

sequenceDiagram
    participant User
    participant TEModule as TE Module (__init__)
    participant SDM as set_device_mesh()
    participant FSDP2 as fully_shard()
    participant RP as reset_parameters()
    participant FWD as forward()
    participant DCP as Torch DCP

    User->>TEModule: __init__(tp_mesh, weight_mesh)
    TEModule->>TEModule: init_fp8_metadata()
    TEModule->>SDM: set_device_mesh(tp_mesh, weight_mesh)
    SDM->>SDM: _convert_param_to_dtensor_param()<br/>plain param → DTensor(Shard/Replicate)
    SDM->>SDM: set amax_reduction_group<br/>on Float8CurrentScalingQuantizer
    SDM-->>TEModule: params are now DTensors
    TEModule->>RP: reset_parameters(defer_init=device=="meta")
    RP->>RP: _set_tensor_parallel_attributes()

    User->>FSDP2: fully_shard(model, mesh[dp_dims])
    FSDP2->>FSDP2: detects DTensor Shard(dim=0)<br/>→ uses _StridedShard for DP-TP overlap
    FSDP2-->>User: model params are FSDP-sharded DTensors

    Note over User,FSDP2: If meta device: call reset_parameters() now

    loop Training Step
        FSDP2->>FWD: all-gather DTensor shards → TP-sharded DTensor
        FWD->>FWD: _extract_trainable_tensor_from_dtensor()<br/>_ToLocalIdentity preserves object identity
        FWD->>FWD: TE C++ kernels on local Tensor
        FWD-->>FSDP2: grad → DTensor.grad via _ToLocalIdentity.backward
    end

    User->>DCP: save({"app": AppState(model, optimizer)})
    DCP->>DCP: AppState.state_dict()<br/>evict _extra_state, clear empty optim states
    DCP-->>User: checkpoint written

    User->>DCP: load(state_dict, checkpoint_id)
    DCP->>DCP: AppState.load_state_dict()<br/>set_state_dict(strict=False)
    DCP-->>User: model restored to pre-save state
Loading

Comments Outside Diff (2)

  1. transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py, line 63 (link)

    _default_storage device allocation may race in multi-GPU processes

    torch.cuda.current_device() is called at __new__ time to create the default 1-byte storage. When multiple QuantizedTensor instances are created during model construction across CUDA streams, current_device() is guaranteed to return the correct device per-process in standard distributed setups. However, if this constructor is ever triggered on a CPU-only system (e.g., in unit tests without GPUs) or before a CUDA context is initialized, it will either fail or silently produce a CUDA-device storage that mismatches the tensor's actual device.

    A more robust approach would be to use device=self.device (resolved from rowwise_data or columnwise_data if available, otherwise a CPU default):

    _device = (rowwise_data or columnwise_data).device if (rowwise_data or columnwise_data) else torch.device("cpu")
    instance._default_storage = torch.UntypedStorage(1, device=_device)

    The same pattern is used identically in mxfp8_tensor_storage.py (line 101) and nvfp4_tensor_storage.py (line 126).

  2. transformer_engine/pytorch/module/layernorm_mlp.py, line 1371-1374 (link)

    Bug fix: backward update_usage guard was always false before this change

    The original condition isinstance(ctx.fc1_weight_quantizer, QuantizedTensorStorage) was checking if the quantizer object is a QuantizedTensorStorage — but a quantizer is a Quantizer subclass, never a QuantizedTensorStorage. This meant the guard was always False, so ctx.fc1_weight.update_usage(columnwise_usage=True) was dead code and was never called during backward.

    The fix correctly changes the subject of the isinstance check to ctx.fc1_weight. This is a real bug fix bundled into this PR that re-enables columnwise usage tracking on the FC1 weight during the backward pass for block-format quantized tensors. It's worth highlighting in the PR description and/or a CHANGELOG entry so reviewers don't confuse it with a refactor.

Last reviewed commit: 82780a1

@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from 4ec2947 to dbb9d14 Compare March 4, 2026 18:10
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch from fcdd5bd to c912f5b Compare March 5, 2026 16:06
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch 5 times, most recently from bc82f02 to 267f1df Compare March 10, 2026 01:30
@vthumbe1503
Copy link
Collaborator

/te-ci L1 pytorch

@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch 4 times, most recently from f0b3cae to af7362a Compare March 12, 2026 15:26
@cspades cspades force-pushed the cye/fsdp2-tp-dcp branch 4 times, most recently from 9435382 to 15df86f Compare March 16, 2026 19:16
@cspades
Copy link
Member Author

cspades commented Mar 16, 2026

/te-ci L1 pytorch

cspades and others added 12 commits March 17, 2026 08:41
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
…ess.

Signed-off-by: Cory Ye <cye@nvidia.com>
… are still model parity tested.

Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades
Copy link
Member Author

cspades commented Mar 17, 2026

Signed-off-by: Cory Ye <cye@nvidia.com>
@cspades
Copy link
Member Author

cspades commented Mar 17, 2026

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants